﻿
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.IO;
using System.Reflection;
using CatEars.Core.Collections;

namespace CatEars.Core.Data
{
    /// <summary>
    /// SQL driver manager.
    /// Keeps a case-insensitive registry of driver name -> <see cref="DbProviderFactory"/>.
    /// It starts with the built-in SqlClient/OleDb/Odbc factories and can be extended by
    /// scanning assemblies for public concrete <see cref="DbProviderFactory"/> subclasses
    /// that expose a public static "Instance" field or property.
    /// </summary>
    public class DbDriverManage
    {
        /// <summary>
        /// File names (compared case-insensitively) that are known not to be loadable as drivers;
        /// they are skipped when scanning a driver folder.
        /// </summary>
        static readonly HashSet<string> _ignoreFileName = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
        {
            "SQLite.Designer.dll",
            "System.Data.SQLite.EF6.dll",
            "System.Data.SQLite.Linq.dll",
        };

        /// <summary>
        /// The <see cref="Type"/> of <see cref="DbProviderFactory"/>, used to filter candidate types.
        /// </summary>
        static readonly Type PROVIDER_TYPE = typeof(DbProviderFactory);

        /// <summary>
        /// Registered drivers, keyed by driver name (case-insensitive).
        /// </summary>
        readonly Dictionary<string, DbProviderFactory> m_drivers
            = new Dictionary<string, DbProviderFactory>(StringComparer.OrdinalIgnoreCase);

        /// <summary>
        /// Default constructor: registers the built-in SqlClient, OleDb and Odbc drivers.
        /// </summary>
        public DbDriverManage()
        {
            AddDriver(System.Data.SqlClient.SqlClientFactory.Instance);
            AddDriver(System.Data.OleDb.OleDbFactory.Instance);
            AddDriver(System.Data.Odbc.OdbcFactory.Instance);
        }

        /// <summary>
        /// Registers the built-in drivers, then tries to load drivers from every *.dll
        /// file in the given folder. Files that cannot be loaded are skipped.
        /// A folder that does not exist is ignored.
        /// </summary>
        /// <param name="strDirName">Driver folder.</param>
        public DbDriverManage(string strDirName) : this()
        {
            DirectoryInfo dir = new DirectoryInfo(strDirName);
            if (!dir.Exists)
            {
                return;
            }
            foreach (FileInfo file in dir.GetFiles("*.dll"))
            {
                // A 3-character extension pattern also matches longer extensions (e.g. "x.dllx")
                // on .NET Framework, so check the extension explicitly.
                if (!string.Equals(file.Extension, ".dll", StringComparison.OrdinalIgnoreCase))
                {
                    continue;
                }
                if (_ignoreFileName.Contains(file.Name))
                {
                    continue;
                }
                AddDriverFromFlie(file);
            }
        }

        /// <summary>
        /// Names of all registered drivers, including aliases.
        /// This is a live view of the registry, not a snapshot.
        /// </summary>
        public IEnumerable<string> DirverNames { get { return m_drivers.Keys; } }

        /// <summary>
        /// Looks up a driver by name (case-insensitive).
        /// </summary>
        /// <param name="strDriverName">Driver name.</param>
        /// <returns>The driver factory, or null when the name is null or not registered.</returns>
        public DbProviderFactory this[string strDriverName]
        {
            get
            {
                DbProviderFactory result = null;
                if (strDriverName != null)
                {
                    // TryGetValue leaves result null on a miss.
                    m_drivers.TryGetValue(strDriverName, out result);
                }
                return result;
            }
        }

        /// <summary>
        /// Loads drivers from the file at the given path.
        /// </summary>
        /// <param name="strFileFullName">Full path of the file to load.</param>
        /// <returns>Names of the drivers registered from the file; empty when the path is
        /// null/blank, the file does not exist, or nothing could be loaded.</returns>
        public string[] AddDriverFromFlie(string strFileFullName)
        {
            if (string.IsNullOrWhiteSpace(strFileFullName))
            {
                return EmptyArray.StringArray;
            }
            return AddDriverFromFlie(new FileInfo(strFileFullName));
        }

        /// <summary>
        /// Loads drivers from the given assembly file.
        /// Every public, concrete <see cref="DbProviderFactory"/> subclass with a public static
        /// "Instance" field or property gets registered. Loading is best-effort: an unloadable
        /// file yields an empty result, and types that fail are skipped.
        /// </summary>
        /// <param name="fileInfo">File to load.</param>
        /// <returns>Names of the drivers registered from the file.</returns>
        public string[] AddDriverFromFlie(FileInfo fileInfo)
        {
            if (fileInfo == null || !fileInfo.Exists)
            {
                return EmptyArray.StringArray;
            }

            IEnumerable<Type> types;
            try
            {
                types = GetLoadableTypes(Assembly.LoadFile(fileInfo.FullName));
            }
            catch
            {
                // Not a loadable managed assembly (native dll, bad image, access denied, ...): skip it.
                return EmptyArray.StringArray;
            }

            List<string> result = new List<string>();
            foreach (Type type in types)
            {
                try
                {
                    DbProviderFactory provider = GetProviderInstance(type);
                    if (provider != null)
                    {
                        result.Add(AddDriver(provider));
                    }
                }
                catch
                {
                    // Best-effort: skip a type whose static initialisation or accessor throws.
                }
            }
            return result.ToArray();
        }

        /// <summary>
        /// Returns the types of an assembly.
        /// When some types cannot be loaded (e.g. a missing dependency), the types that did
        /// load are still returned instead of failing the whole assembly.
        /// </summary>
        private static IEnumerable<Type> GetLoadableTypes(Assembly assembly)
        {
            try
            {
                return assembly.GetTypes();
            }
            catch (ReflectionTypeLoadException ex)
            {
                List<Type> loaded = new List<Type>();
                if (ex.Types != null)
                {
                    foreach (Type t in ex.Types)
                    {
                        // Entries for types that failed to load are null.
                        if (t != null)
                        {
                            loaded.Add(t);
                        }
                    }
                }
                return loaded;
            }
        }

        /// <summary>
        /// Returns the provider singleton exposed by <paramref name="type"/>, or null when
        /// <paramref name="type"/> is not a usable provider type.
        /// A usable provider type is a public, non-abstract <see cref="DbProviderFactory"/>
        /// subclass with a public static "Instance" field or property.
        /// </summary>
        private static DbProviderFactory GetProviderInstance(Type type)
        {
            if (!type.IsClass || !type.IsPublic || type.IsAbstract)
            {
                return null;
            }
            if (!type.IsSubclassOf(PROVIDER_TYPE))
            {
                return null;
            }
            // Prefer a static field named "Instance"; fall back to a static property.
            return GetFieldValue<DbProviderFactory>(type, "Instance")
                ?? GetPropValue<DbProviderFactory>(type, "Instance");
        }

        /// <summary>
        /// Registers a driver under a name derived from its type.
        /// </summary>
        /// <param name="provider">DbProviderFactory instance.</param>
        /// <returns>The driver name it was registered under.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
        public string AddDriver(DbProviderFactory provider)
        {
            return AddDriver(provider, null);
        }

        /// <summary>
        /// Registers a driver, replacing any driver already registered under the same name.
        /// "MysqlClient" is additionally registered as "Mysql", and "SqlClient" as "Sqlserver".
        /// </summary>
        /// <param name="provider">DbProviderFactory instance.</param>
        /// <param name="strName">Custom name. When null or empty, the name is derived from the
        /// type name with a trailing "Factory" removed.</param>
        /// <returns>The driver name it was registered under.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
        public string AddDriver(DbProviderFactory provider, string strName)
        {
            if (provider == null)
            {
                throw new ArgumentNullException("provider");
            }

            string strDriverInfo = string.IsNullOrEmpty(strName)
                ? GetDefaultDriverName(provider)
                : strName;

            m_drivers[strDriverInfo] = provider;
            if (strDriverInfo.Equals("MysqlClient", StringComparison.OrdinalIgnoreCase))
            {
                // Alias for MysqlClient.
                m_drivers["Mysql"] = provider;
            }
            if (strDriverInfo.Equals("SqlClient", StringComparison.OrdinalIgnoreCase))
            {
                // Alias for SqlClient.
                m_drivers["Sqlserver"] = provider;
            }
            return strDriverInfo;
        }

        /// <summary>
        /// Derives the default driver name from the provider's type name by removing a trailing
        /// "Factory" (case-insensitive). The suffix is kept when it is the whole type name, so
        /// the resulting name is never empty.
        /// </summary>
        private static string GetDefaultDriverName(DbProviderFactory provider)
        {
            const string suffix = "Factory";
            string typeName = provider.GetType().Name;
            if (typeName.Length > suffix.Length
                && typeName.EndsWith(suffix, StringComparison.OrdinalIgnoreCase))
            {
                return typeName.Substring(0, typeName.Length - suffix.Length);
            }
            return typeName;
        }

        /// <summary>
        /// Reads the value of a public static field.
        /// </summary>
        /// <param name="type">Type declaring the field.</param>
        /// <param name="strName">Field name.</param>
        /// <returns>The field value as <typeparamref name="T"/>, or null when the field does not
        /// exist or its value is not a <typeparamref name="T"/>.</returns>
        private static T GetFieldValue<T>(Type type, string strName) where T : class
        {
            FieldInfo field = type.GetField(strName, BindingFlags.Public | BindingFlags.Static);
            return field == null ? null : field.GetValue(null) as T;
        }

        /// <summary>
        /// Reads the value of a public static readable property.
        /// </summary>
        /// <param name="type">Type declaring the property.</param>
        /// <param name="strName">Property name.</param>
        /// <returns>The property value as <typeparamref name="T"/>, or null when the property does
        /// not exist, is not readable, or its value is not a <typeparamref name="T"/>.</returns>
        private static T GetPropValue<T>(Type type, string strName) where T : class
        {
            PropertyInfo prop = type.GetProperty(strName, BindingFlags.Public | BindingFlags.Static);
            if (prop == null || !prop.CanRead)
            {
                return null;
            }
            return prop.GetValue(null, null) as T;
        }

        /// <summary>
        /// Lists the registered driver names (comma-separated).
        /// </summary>
        /// <returns>Description of the registered drivers.</returns>
        public override string ToString()
        {
            try
            {
                return "已有驱动:" + string.Join(",", this.DirverNames);
            }
            catch
            {
                // Fall back to the default representation if the names cannot be enumerated.
                return base.ToString();
            }
        }
    }
}
